
include("../src/qmcmary.jl")
using ..qmcmary

using LinearAlgebra
using ForwardDiff


"""
    bbar_mu()

Check the derivative of a B matrix with respect to the chemical potentials `mu`.

Builds a `lattice_chain` with `Nx = 2` and a random ±1 auxiliary-field configuration.
It takes the first component returned by `bmat_IsingAD` for time slice 1 (presumably
the real part — TODO confirm) and right-multiplies it by `Diagonal(exp.(-dt*mu))`.
The ForwardDiff Jacobian column for `mu[1]` is then compared against a forward finite
difference with step `dmu = 0.01`. Both matrices are printed, and an `AssertionError`
is thrown if they differ by more than `atol = 1e-3`.
"""
function bbar_mu()
    # Set up the system
    Nx = 2
    hk = lattice_chain(Nx, -1.0+0.0im)
    Ui = ones(Nx)
    Nt = 20
    # Random ±1 auxiliary-field configuration, one row per time slice
    hscfg = rand([-1, 1], Nt, Nx)
    sp = default_splitting(Nt, hk, Ui; Z2=true)
    ehk, Ui = unpack_splitting(sp, 1)
    # Slice-1 matrix as a function of mu (one chemical potential per 2*Nx index)
    fbmatr = (mu) -> bmat_IsingAD(ehk, hscfg[1, :], Ui,
        zeros(Nx), zeros(Nx), ones(Nx), sp.dt)[1]*Diagonal(exp.(-sp.dt*mu))
    mus = -1.8*ones(2*Nx)
    bmatr = fbmatr(mus)
    grad = ForwardDiff.jacobian(fbmatr, mus)
    # Forward finite difference along mu[1]
    dmu = 0.01
    mu2 = copy(mus)
    mu2[1] += dmu
    bmatr2 = fbmatr(mu2)
    # Column 1 of the Jacobian is d vec(B)/d mu[1]; reshape back to a matrix and
    # scale by the step so it is directly comparable with the finite difference.
    dbmat_ad = reshape(grad[:, 1], 2*Nx, 2*Nx)*dmu
    dbmat_fd = bmatr2 - bmatr
    println(dbmat_ad)
    println(dbmat_fd)
    @assert all(isapprox.(dbmat_fd, dbmat_ad, atol=1e-3))
    return nothing
end


"""
    bbar_U()

Check derivatives of the B matrices with respect to the interaction strengths `U`.

Prints the ForwardDiff Jacobian of the slice-1 matrix with respect to `U[1]`, next to
a finite difference. Then, for every time slice `bi` and site `xi`, it asserts that:

1. The ForwardDiff Jacobians of both components returned by `bmat_IsingAD` match a
   forward finite difference with step `dU = 0.01` (`atol = 1e-3`).
2. Rows `xi` and `xi + Nx` of those Jacobians match the slice-`bi` entry returned by
   `pbpU_calc_xi`.
3. `pbpU_calc_xi` gives the same result with and without a precomputed `tapemap`.

Throws an `AssertionError` on the first mismatch.
"""
function bbar_U()
    # Set up the system
    Nx = 2
    hk = lattice_chain(Nx, -1.0+0.0im)
    Ui = 0.1*ones(Nx)
    Nt = 20
    # Random ±1 auxiliary-field configuration, one row per time slice
    hscfg = rand([-1, 1], Nt, Nx)
    sp = default_splitting(Nt, hk, Ui; Z2=true)
    nx, ny, nz = zeros(Nx), zeros(Nx), ones(Nx)
    dU = 0.01
    # Quick look at slice 1: derivative with respect to U[1]
    ehk1, U1 = unpack_splitting(sp, 1)
    fbmat1 = (Ua) -> bmat_IsingAD(ehk1, hscfg[1, :], Ua, nx, ny, nz, sp.dt)[1]
    println(U1)
    grad = ForwardDiff.jacobian(fbmat1, U1)
    println(size(grad))
    U1p = copy(U1)
    U1p[1] += dU
    println(reshape(grad[:, 1], 2*Nx, 2*Nx))
    println(fbmat1(U1p)[1, :] - fbmat1(U1)[1, :])
    # Reference derivatives from the hand-written routine. Each call covers all time
    # slices, so compute it once per site rather than once per (slice, site) pair.
    tapemap = pbpU_calc_tapemap(sp)
    pbpU_plain = [pbpU_calc_xi(sp, hscfg[:, xi], xi) for xi in Base.OneTo(Nx)]
    pbpU_taped = [pbpU_calc_xi(sp, hscfg[:, xi], xi; tapemap=tapemap)
                  for xi in Base.OneTo(Nx)]
    for bi in Base.OneTo(Nt)
        ehk, Ub = unpack_splitting(sp, bi)
        fbmatr = (Ua) -> bmat_IsingAD(ehk, hscfg[bi, :], Ua, nx, ny, nz, sp.dt)[1]
        fbmati = (Ua) -> bmat_IsingAD(ehk, hscfg[bi, :], Ua, nx, ny, nz, sp.dt)[2]
        # Jacobian columns are d vec(B)/d U[xi]; reshape to (row, col, xi)
        gradr = reshape(ForwardDiff.jacobian(fbmatr, Ub), 2*Nx, 2*Nx, Nx)
        gradi = reshape(ForwardDiff.jacobian(fbmati, Ub), 2*Nx, 2*Nx, Nx)
        # Unperturbed matrices do not depend on xi
        bmatr, bmati = bmat_IsingAD(ehk, hscfg[bi, :], Ub, nx, ny, nz, sp.dt)
        for xi in Base.OneTo(Nx)
            Ub2 = copy(Ub)
            Ub2[xi] += dU
            bmatr2, bmati2 = bmat_IsingAD(ehk, hscfg[bi, :], Ub2, nx, ny, nz, sp.dt)
            dimatr = bmatr2 - bmatr
            println(dimatr, " vs ", gradr[:, :, xi]*dU)
            println(abs.(dimatr - gradr[:, :, xi]*dU))
            @assert all(isapprox.(dimatr, gradr[:, :, xi]*dU, atol=1e-3))
            dimati = bmati2 - bmati
            @assert all(isapprox.(dimati, gradi[:, :, xi]*dU, atol=1e-3))
            # pbpU_calc_xi returns only rows xi and xi+Nx (presumably the only rows
            # that depend on U[xi] — TODO confirm)
            pbrpU2, pbipU2 = pbpU_plain[xi]
            @assert all(isapprox.(gradr[[xi, xi+Nx], :, xi], pbrpU2[bi, :, :]))
            @assert all(isapprox.(gradi[[xi, xi+Nx], :, xi], pbipU2[bi, :, :]))
            # The tapemap-accelerated path must agree with the plain one
            pbrpU3, pbipU3 = pbpU_taped[xi]
            @assert all(isapprox.(pbrpU3[bi, :, :], pbrpU2[bi, :, :]))
            @assert all(isapprox.(pbipU3[bi, :, :], pbipU2[bi, :, :]))
        end
    end
    return nothing
end


"""
    bbar_U2()

Check `bbar_calc` against finite differences of `log|det(I + B_Nt ⋯ B_1)|`.

Builds a `lattice_chain` with `Nx = 2`, `Nt = 20` time slices and a random ±1
configuration. It initializes the slice matrices with `initialize_SS`, then perturbs
every entry `(x1, x2)` of every slice matrix `bi` by `dB = 0.01`. For each entry it
asserts that the resulting decrease of `log|det(I + prod(reverse(allbmats)))|` equals
`bbarad[bi, x1, x2] * dB` within `atol = 1e-4`. Throws an `AssertionError` on the
first mismatch.
"""
function bbar_U2()
    # Set up the system
    Nx = 2
    hk = lattice_chain(Nx, -1.0+0.0im)
    Ui = zeros(Nx)
    Nt = 20
    Ng = 5
    nx = zeros(Nx)
    ny = zeros(Nx)
    nz = ones(Nx)
    # Random ±1 auxiliary-field configuration, one row per time slice
    hscfg = rand([-1, 1], Nt, Nx)
    sp = default_splitting(Nt, hk, Ui; Z2=true)
    _, _, allbmats, ss = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)
    bbarad = bbar_calc(ss, allbmats)
    # log|det(I + B_Nt ⋯ B_1)| of the current allbmats; logabsdet avoids the
    # overflow/underflow that det followed by log can hit.
    logabsdet_full = () -> first(logabsdet(I(2Nx) + prod(reverse(allbmats))))
    lasgn1 = logabsdet_full()
    dB = 0.01
    for bi in Base.OneTo(Nt), x1 in Base.OneTo(2Nx), x2 in Base.OneTo(2Nx)
        orig = allbmats[bi][x1, x2]
        allbmats[bi][x1, x2] = orig + dB
        lasgn2 = logabsdet_full()
        # Restore the exact value: (orig + dB) - dB need not round-trip in floating
        # point, and any drift would corrupt later comparisons against lasgn1.
        allbmats[bi][x1, x2] = orig
        ndlasgn = lasgn1 - lasgn2
        @assert isapprox(ndlasgn, bbarad[bi, x1, x2]*dB, atol=1e-4)
    end
    return nothing
end

# Script entry point: run the U-derivative checks in order.
# The mu-derivative check is disabled; uncomment to run it.
# bbar_mu()
for check in (bbar_U, bbar_U2)
    check()
end
